import pandas as pd
import numpy as np
import pickle as pickle
import os as os

## Example Gillespie run simulating the che-1 HD model.
# Caution: simulating 24 h of model time takes approx. 48 h of wall-clock time on a single core.

class CHE1_gils:
    """Gillespie (stochastic simulation algorithm) simulator of the che-1 HD model.

    The state of a single cell is a vector ``N`` of ``N_S`` integer species
    counts, ordered as ``species_lbl``.  ``N_R`` reactions act on it; their net
    stoichiometry is stored row-wise in ``dN`` and their rate constants are
    read from a ``key: value`` parameter file.
    """

    def __init__(self, param_file):
        """Load parameters from *param_file* and build the reaction network.

        Prints the parsed parameters and the stoichiometry matrix.

        Raises:
            ValueError: if ``propensity_function`` in the file is missing or is
                not ``basic`` (the only implemented variant).
        """
        # number of species and reactions per cell
        self.N_S = 9
        self.N_R = 15

        # species labels and their position in the state vector N
        self.species_lbl = ['O', 'OC', 'OHC', 'OH', 'H', 'C', 'M', 'D', 'DC']
        self.species_ind = {lbl: i for i, lbl in enumerate(self.species_lbl)}

        self.param_list = self.read_param_file(param_file)
        print(self.param_list)

        propensity_function = self.param_list.get('propensity_function')
        if propensity_function != 'basic':
            # Fail here with a clear message instead of an AttributeError on
            # the then-missing self.dN.
            raise ValueError(
                "unsupported propensity_function %r; only 'basic' is implemented"
                % (propensity_function,))
        self.calculate_propensity = self.calculate_propensity_basic
        self.dN = self.construct_dN()
        print(self.dN)

    def read_param_file(self, param_file):
        """Parse a ``key: value`` parameter file into a dict.

        Keys are whitespace-stripped.  Values that parse as numbers are stored
        as ``float``; anything else is stored as a stripped string.  Lines
        without a ``:`` (e.g. blank lines) are ignored.  Only the first ``:``
        separates key from value, so values may themselves contain ``:``.
        """
        param_list = {}
        with open(param_file) as f:
            for line in f:
                key, sep, val = line.partition(":")
                if not sep:
                    continue
                key = key.strip()
                try:
                    # builtin float: the np.float alias was removed in NumPy 1.24
                    param_list[key] = float(val)
                except ValueError:
                    param_list[key] = val.strip()
        return param_list

    def S_ind(self, S_lbl):
        """Return the state-vector index of species label *S_lbl*."""
        return self.species_ind[S_lbl]

    def construct_dN(self):
        """Build the (N_R, N_S) integer stoichiometry matrix.

        Row r holds the net change of every species when reaction r fires.
        Species that appear on both sides of a reaction (e.g. OHC in r8) have
        no net change and therefore no entry.
        """
        net_change = [
            {'O': -1, 'C': -1, 'OC': +1},    # r0: O + C --> OC
            {'O': +1, 'C': +1, 'OC': -1},    # r1: OC --> O + C
            {'O': -1, 'H': -1, 'OH': +1},    # r2: O + H --> OH
            {'O': +1, 'H': +1, 'OH': -1},    # r3: OH --> O + H
            {'OH': -1, 'C': -1, 'OHC': +1},  # r4: OH + C --> OHC
            {'OH': +1, 'C': +1, 'OHC': -1},  # r5: OHC --> OH + C
            {'OC': -1, 'H': -1, 'OHC': +1},  # r6: OC + H --> OHC
            {'OC': +1, 'H': +1, 'OHC': -1},  # r7: OHC --> OC + H
            {'M': +1},                       # r8: OHC --> OHC + M
            {'M': +1},                       # r9: OC --> OC + M
            {'C': +1},                       # r10: M --> M + C
            {'M': -1},                       # r11: M --> 0
            {'C': -1},                       # r12: C --> 0
            {'D': -1, 'C': -1, 'DC': +1},    # r13: D + C --> DC
            {'D': +1, 'C': +1, 'DC': -1},    # r14: DC --> D + C
        ]
        # integer dtype keeps the state vector N integer-valued after updates
        dN = np.zeros((self.N_R, self.N_S), dtype=int)
        for r, changes in enumerate(net_change):
            for lbl, delta in changes.items():
                dN[r, self.S_ind(lbl)] = delta
        return dN

    def initialize_reactants(self, init_data):
        """Return an integer state vector with the counts given in *init_data*.

        *init_data* maps species labels to initial counts; unlisted species
        start at 0.
        """
        N = np.zeros(self.N_S, dtype=int)
        for lbl, count in init_data.items():
            N[self.S_ind(lbl)] = count
        return N

    def calculate_propensity_basic(self, N):
        """Return the mass-action propensity of each of the N_R reactions in state *N*."""
        p = self.param_list
        ix = self.species_ind
        O, OC, OHC, OH = N[ix['O']], N[ix['OC']], N[ix['OHC']], N[ix['OH']]
        H, C, M, D, DC = N[ix['H']], N[ix['C']], N[ix['M']], N[ix['D']], N[ix['DC']]

        a = np.zeros(self.N_R, dtype=float)
        a[0] = p['fO'] * O * C     # r0: O + C --> OC
        a[1] = p['bO'] * OC        # r1: OC --> O + C
        a[2] = p['fO'] * O * H     # r2: O + H --> OH
        a[3] = p['bH'] * OH        # r3: OH --> O + H
        a[4] = p['fO'] * OH * C    # r4: OH + C --> OHC
        a[5] = p['bOs'] * OHC      # r5: OHC --> OH + C
        a[6] = p['fO'] * OC * H    # r6: OC + H --> OHC
        a[7] = p['bHs'] * OHC      # r7: OHC --> OC + H
        a[8] = p['fM'] * OHC       # r8: OHC --> OHC + M
        a[9] = p['fM'] * OC        # r9: OC --> OC + M
        a[10] = p['fC'] * M        # r10: M --> M + C
        a[11] = p['bM'] * M        # r11: M --> 0
        a[12] = p['bC'] * C        # r12: C --> 0
        a[13] = p['fO'] * D * C    # r13: D + C --> DC (uses the same binding rate fO)
        a[14] = p['bD'] * DC       # r14: DC --> D + C
        return a

    def _species_term(self, i, coeff):
        """Return '<label>_a' for the first species with dN[i, :] == coeff, or '0' if none."""
        hits = np.flatnonzero(self.dN[i, :] == coeff)
        return self.species_lbl[hits[0]] + '_a' if hits.size else '0'

    def print_reactions(self, a, i):
        """Print reaction *i* as 'R[i]: <consumed>--><produced>, a=<propensity>'.

        Only the first consumed and the first produced species are shown.
        """
        print("R[%d]: %s-->%s, a=%f"
              % (i, self._species_term(i, -1), self._species_term(i, 1), a[i]))

    def print_molecules(self, t, N):
        """Print time *t* followed by the count of every species on one line."""
        print("t=%f" % t, end=' ')
        for j, lbl in enumerate(self.species_lbl):
            print(lbl + ":%d" % N[j], end=' ')
        print()

    def init_save_data(self, t, N, species_list):
        """Start a trajectory record: [times, counts of species_list[0], ...].

        Element 0 is a 1-element time array; every further element is a
        (1, 1) array with the count of the corresponding species.
        """
        sim_data = [np.array([t])]
        for S_lbl in species_list:
            N_save = N[[self.S_ind(S_lbl)]]
            sim_data.append(np.array([N_save]))
        return sim_data

    def add_save_data(self, sim_data, t, N, species_list):
        """Append time *t* and the current counts to a record from ``init_save_data``.

        Copies every array on each call; ``propagate_reactions`` no longer uses
        it for that reason, but it is kept for existing callers.
        """
        sim_data[0] = np.hstack((sim_data[0], t))
        for c, S_lbl in enumerate(species_list, start=1):
            N_save = N[[self.S_ind(S_lbl)]]
            sim_data[c] = np.vstack((sim_data[c], N_save))
        return sim_data

    def propagate_reactions(self, t, N, T_sim, DT_out, save_species_list,
                            generate_moments=False,
                            bC_schedule=((12 * 3600, 0.0064), (24 * 3600, 0.00023))):
        """Run the Gillespie direct method from time *t* until *T_sim*.

        The run stops early if both C and DC are 0, or if no reaction can fire.

        Args:
            t: start time.  The default ``bC_schedule`` and the hourly progress
                print (up to hour 36) treat time as seconds.
            N: initial integer state vector (see ``initialize_reactants``).
            T_sim: end time.
            DT_out: spacing of the output grid; must be positive.
            save_species_list: labels of the species to record.
            generate_moments: also return time-averaged first/second moments.
            bC_schedule: (t_switch, bC) pairs; whenever t > t_switch the paired
                value replaces the C degradation rate (later matching entries
                win).  Before the first switch the value from the parameter
                file is used.  Pass ``()`` to keep bC constant.

        Returns:
            (save_event_times, N, save_data) where ``save_event_times`` is
            ``[0]`` followed by the event time that triggered each grid save,
            ``N`` is the final state, and ``save_data['trajectory_data']`` is a
            list whose element 0 is the 1-D array of output times (start time,
            then the grid points DT_out apart up to T_sim) and whose element k>=1
            is an (n, 1) array of counts of ``save_species_list[k-1]``.  With
            ``generate_moments``, ``save_data['moment_data']`` is a (2, N_S)
            array of time-averaged <N> and <N**2> (zeros if no step ran).

        Raises:
            ValueError: if DT_out is not positive.
        """
        if DT_out <= 0:
            raise ValueError("DT_out must be positive, got %r" % (DT_out,))

        save_idx = [self.S_ind(lbl) for lbl in save_species_list]
        save_event_times = [0]
        out_times = [t]
        out_counts = [N[save_idx]]
        # next grid point strictly after t (the state at t itself is saved above);
        # using k * DT_out avoids drift from repeatedly adding DT_out
        k_out = int(np.floor(t / DT_out)) + 1

        if generate_moments:
            T = 0.0
            N_m = np.zeros((2, self.N_S), dtype=float)

        ind_C = self.S_ind('C')
        ind_DC = self.S_ind('DC')
        bC_base = self.param_list['bC']
        t_count = 1
        try:
            while t < T_sim:
                if N[ind_C] == 0 and N[ind_DC] == 0:
                    break

                # time-dependent C degradation rate; applied before the
                # propensities are computed for the current state
                bC = bC_base
                for t_switch, bC_new in bC_schedule:
                    if t > t_switch:
                        bC = bC_new
                self.param_list['bC'] = bC

                a = self.calculate_propensity(N)
                A = a.cumsum()
                A_tot = A[-1]
                if A_tot <= 0:
                    # no reaction can fire: the state is absorbing
                    break

                # waiting time to the next reaction
                dt = np.random.exponential(1 / A_tot, 1)[0]

                # pick reaction i with probability a[i] / A_tot
                r = A_tot * np.random.rand()
                react_ind = int(np.searchsorted(A, r, side='right'))
                if react_ind >= self.N_R:
                    # r rounded up to A_tot: take the last reaction that can fire
                    react_ind = int(np.flatnonzero(a)[-1])

                if generate_moments:
                    T += dt
                    N_m[0, :] += dt * N
                    N_m[1, :] += dt * N**2

                t = t + dt

                # N (pre-reaction) held on [t - dt, t): record it at every grid
                # point passed, not just the first one
                while t > k_out * DT_out and k_out * DT_out <= T_sim:
                    out_times.append(k_out * DT_out)
                    out_counts.append(N[save_idx])
                    save_event_times.append(t)
                    k_out += 1

                # hourly progress report (hours 1..36)
                if t_count <= 36 and round(t / 3600) == t_count:
                    print(t_count)
                    t_count += 1

                N = N + self.dN[react_ind, :]
        finally:
            # do not leak the scheduled rate into later runs
            self.param_list['bC'] = bC_base

        counts = np.array(out_counts).reshape(len(out_times), len(save_idx))
        trajectory_data = [np.array(out_times, dtype=float)]
        trajectory_data += [counts[:, [k]] for k in range(len(save_idx))]
        save_data = {'trajectory_data': trajectory_data}

        if generate_moments:
            if T > 0:
                N_m = N_m / T
            save_data['moment_data'] = N_m

        return (save_event_times, N, save_data)

    def run_full_simulation(self, N, save_species_list, T_sim=36 * 3600, DT_out=1):
        """Simulate from t=0 to *T_sim* and return ([save_data], save_event_times).

        The second element is the first value returned by
        ``propagate_reactions`` (the list of save-triggering event times),
        not the final simulation time.
        """
        run_data = []
        (t, N, sim_data) = self.propagate_reactions(0, N, T_sim, DT_out, save_species_list)
        run_data.append(sim_data)
        return (run_data, t)


# Simulate one trajectory from the initial condition below and write the saved
# species counts (one row per output time) to CSV.
save_species = ["OHC", "OC", "OH", "H", "C", "M", "D", "DC"]

che1sim = CHE1_gils("params_che1.txt")

N = che1sim.initialize_reactants({'O': 1, 'OC': 0, 'OHC': 0, 'OH': 0, 'H': 5000, 'C': 900, 'M': 7, 'D': 500, 'DC': 0})
testrun, t = che1sim.run_full_simulation(N, save_species)

# trajectory_data[0] holds the output times; trajectory_data[k] (k >= 1) is an
# (n, 1) column of counts for save_species[k-1].  Flatten each column so the
# CSV gets plain numbers instead of the string form of one-element arrays.
trajectory = testrun[0]["trajectory_data"]
df = pd.DataFrame({lbl: np.ravel(trajectory[k])
                   for k, lbl in enumerate(save_species, start=1)})
df.to_csv("210712_che1_MODEL1.csv")
